{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython import display\n",
    "\n",
    "plt.style.use('seaborn-white')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read and process data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下载数据集[百度云盘Datasets](https://pan.baidu.com/s/1gAFZ9gSf4pHJBt5W6_PgPQ \"提取码: gxk4\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = open('/home/lidong/Datasets/ML/rnn/shakespeare_input.txt', 'r').read()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Process data and calculate indexes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data has 4573338 characters, 67 unique\n",
      "{'j': 0, 'n': 1, 'f': 2, 'v': 3, 'F': 4, 'c': 5, 'Y': 6, 'R': 7, 's': 8, 'o': 9, '\\n': 10, 'X': 11, '[': 12, 'g': 13, 'G': 14, 'e': 15, 'P': 16, 'm': 17, 'r': 18, '.': 19, 'q': 20, 'O': 21, 'x': 22, '&': 23, '!': 24, 'B': 25, 'C': 26, 'D': 27, 'T': 28, 'W': 29, 'V': 30, 'd': 31, 'y': 32, '?': 33, ';': 34, 'M': 35, '3': 36, ' ': 37, 'J': 38, 'h': 39, '$': 40, 'a': 41, 'E': 42, 'Z': 43, 'k': 44, ',': 45, 'u': 46, 'p': 47, '-': 48, 'S': 49, ':': 50, 'L': 51, 't': 52, 'U': 53, 'i': 54, 'w': 55, 'A': 56, 'Q': 57, 'K': 58, 'I': 59, 'N': 60, 'l': 61, 'H': 62, \"'\": 63, 'z': 64, ']': 65, 'b': 66}\n"
     ]
    }
   ],
   "source": [
    "chars = list(set(data))\n",
    "data_size, X_size = len(data), len(chars)\n",
    "print(\"data has %d characters, %d unique\" % (data_size, X_size))\n",
    "char_to_idx = {ch:i for i,ch in enumerate(chars)}\n",
    "idx_to_char = {i:ch for i,ch in enumerate(chars)}\n",
    "print(char_to_idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Constants and Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "H_size = 100 # Size of the hidden layer\n",
    "T_steps = 25 # Number of time steps (length of the sequence) used for training\n",
    "learning_rate = 1e-1 # Learning rate\n",
    "weight_sd = 0.1 # Standard deviation of weights for initialization\n",
    "z_size = H_size + X_size # Size of concatenate(H, X) vector"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Activation Functions and Derivatives\n",
    "\n",
    "#### Sigmoid\n",
    "\n",
    "\\begin{align}\n",
    "\\sigma(x) &= \\frac{1}{1 + e^{-x}}\\\\\n",
    "\\frac{d\\sigma(x)}{dx} &= \\sigma(x) \\cdot (1 - \\sigma(x))\n",
    "\\end{align}\n",
    "\n",
    "#### Tanh\n",
    "\n",
    "\\begin{align}\n",
    "\\frac{d\\text{tanh}(x)}{dx} &= 1 - \\text{tanh}^2(x)\n",
    "\\end{align}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sigmoid(x):\n",
    "    return 1 / (1 + np.exp(-x))\n",
    "\n",
    "\n",
    "def dsigmoid(y):\n",
    "    return y * (1 - y)\n",
    "\n",
    "\n",
    "def tanh(x):\n",
    "    return np.tanh(x)\n",
    "\n",
    "\n",
    "def dtanh(y):\n",
    "    return 1 - y * y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Param:\n",
    "    def __init__(self, name, value):\n",
    "        self.name = name\n",
    "        self.v = value #parameter value\n",
    "        self.d = np.zeros_like(value) #derivative\n",
    "        self.m = np.zeros_like(value) #accumulated squared gradients for AdaGrad"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use random weights with normal distribution (`0`, `weight_sd`) for $tanh$ activation function and (`0.5`, `weight_sd`) for $sigmoid$ activation function.\n",
    "\n",
    "Biases are initialized to zeros."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Parameters:\n",
    "    def __init__(self):\n",
    "        self.W_f = Param('W_f', \n",
    "                         np.random.randn(H_size, z_size) * weight_sd + 0.5)\n",
    "        self.b_f = Param('b_f',\n",
    "                         np.zeros((H_size, 1)))\n",
    "\n",
    "        self.W_i = Param('W_i',\n",
    "                         np.random.randn(H_size, z_size) * weight_sd + 0.5)\n",
    "        self.b_i = Param('b_i',\n",
    "                         np.zeros((H_size, 1)))\n",
    "\n",
    "        self.W_C = Param('W_C',\n",
    "                         np.random.randn(H_size, z_size) * weight_sd)\n",
    "        self.b_C = Param('b_C',\n",
    "                         np.zeros((H_size, 1)))\n",
    "\n",
    "        self.W_o = Param('W_o',\n",
    "                         np.random.randn(H_size, z_size) * weight_sd + 0.5)\n",
    "        self.b_o = Param('b_o',\n",
    "                         np.zeros((H_size, 1)))\n",
    "\n",
    "        #For final layer to predict the next character\n",
    "        self.W_v = Param('W_v',\n",
    "                         np.random.randn(X_size, H_size) * weight_sd)\n",
    "        self.b_v = Param('b_v',\n",
    "                         np.zeros((X_size, 1)))\n",
    "        \n",
    "    def all(self):\n",
    "        return [self.W_f, self.W_i, self.W_C, self.W_o, self.W_v,\n",
    "               self.b_f, self.b_i, self.b_C, self.b_o, self.b_v]\n",
    "        \n",
    "parameters = Parameters()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Forward pass\n",
    "\n",
    "![LSTM](http://blog.varunajayasiri.com/ml/lstm.svg)\n",
    "\n",
    "*Operation $z$ is the concatenation of $x$ and $h_{t-1}$*\n",
    "\n",
    "**矩阵element-wise multiple 哈达玛乘积**, 用上图$\\otimes$表示. 代码中矩阵相乘直接使用`*`符号, 向量之间相乘使用`np.dot`点积.\n",
    "\n",
    "#### Concatenation of $h_{t-1}$ and $x_t$\n",
    "\\begin{align}\n",
    "z & = [h_{t-1}, x_t] \\\\\n",
    "\\end{align}\n",
    "\n",
    "#### LSTM functions\n",
    "\\begin{align}\n",
    "f_t & = \\sigma(W_f \\cdot z + b_f) \\\\\n",
    "i_t & = \\sigma(W_i \\cdot z + b_i) \\\\\n",
    "\\bar{C}_t & = tanh(W_C \\cdot z + b_C) \\\\\n",
    "C_t & = f_t * C_{t-1} + i_t * \\bar{C}_t \\\\\n",
    "o_t & = \\sigma(W_o \\cdot z + b_o) \\\\\n",
    "h_t &= o_t * tanh(C_t) \\\\\n",
    "\\end{align}\n",
    "\n",
    "#### Logits\n",
    "\\begin{align}\n",
    "v_t &= W_v \\cdot h_t + b_v \\\\\n",
    "\\end{align}\n",
    "\n",
    "#### Softmax\n",
    "\\begin{align}\n",
    "\\hat{y_t} &= \\text{softmax}(v_t)\n",
    "\\end{align}\n",
    "\n",
    "$\\hat{y_t}$ is `y` in code and $y_t$ is `targets`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward(x, h_prev, C_prev, p = parameters):\n",
    "    assert x.shape == (X_size, 1)\n",
    "    assert h_prev.shape == (H_size, 1)\n",
    "    assert C_prev.shape == (H_size, 1)\n",
    "    \n",
    "    z = np.row_stack((h_prev, x))\n",
    "    f = sigmoid(np.dot(p.W_f.v, z) + p.b_f.v)\n",
    "    i = sigmoid(np.dot(p.W_i.v, z) + p.b_i.v)\n",
    "    C_bar = tanh(np.dot(p.W_C.v, z) + p.b_C.v)\n",
    "\n",
    "    C = f * C_prev + i * C_bar\n",
    "    o = sigmoid(np.dot(p.W_o.v, z) + p.b_o.v)\n",
    "    h = o * tanh(C)\n",
    "\n",
    "    v = np.dot(p.W_v.v, h) + p.b_v.v\n",
    "    y = np.exp(v) / np.sum(np.exp(v)) #softmax\n",
    "\n",
    "    return z, f, i, C_bar, C, o, h, v, y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Backward pass\n",
    "\n",
    "#### Loss\n",
    "\n",
    "\\begin{align}\n",
    "L_k &= -\\sum_{t=k}^T\\sum_j y_{t,j} \\log \\hat{y_{t,j}} \\\\\n",
    "L &= L_1 \\\\\n",
    "\\end{align}\n",
    "\n",
    "#### Gradients\n",
    "\n",
    "\\begin{align}\n",
    "dv_t &= \\hat{y_t} - y_t \\\\\n",
    "dh_t &= dh'_t + W_v^T \\cdot dv_t \\\\\n",
    "do_t &= dh_t * \\text{tanh}(C_t) \\\\\n",
    "dC_t &= dC'_t + dh_t * o_t * (1 - \\text{tanh}^2(C_t))\\\\\n",
    "d\\bar{C}_t &= dC_t * i_t \\\\\n",
    "di_t &= dC_t * \\bar{C}_t \\\\\n",
    "df_t &= dC_t * C_{t-1} \\\\\n",
    "\\\\\n",
    "df'_t &= f_t * (1 - f_t) * df_t \\\\\n",
    "di'_t &= i_t * (1 - i_t) * di_t \\\\\n",
    "d\\bar{C}'_t &= (1 - \\bar{C}_t^2) * d\\bar{C}_t \\\\\n",
    "do'_t &= o_t * (1 - o_t) * do_t \\\\\n",
    "dz_t &= W_f^T \\cdot df'_t \\\\\n",
    "     &+ W_i^T \\cdot di'_t \\\\\n",
    "     &+ W_C^T \\cdot d\\bar{C}'_t \\\\\n",
    "     &+ W_o^T \\cdot do'_t \\\\\n",
    "\\\\\n",
    "[dh'_{t-1}, dx_t] &= dz_t \\\\\n",
    "dC'_t &= f_t * dC_t\n",
    "\\end{align}\n",
    "\n",
    "* $dC'_t = \\frac{\\partial L_{t+1}}{\\partial C_t}$ and $dh'_t = \\frac{\\partial L_{t+1}}{\\partial h_t}$\n",
    "* $dC_t = \\frac{\\partial L}{\\partial C_t} = \\frac{\\partial L_t}{\\partial C_t}$ and $dh_t = \\frac{\\partial L}{\\partial h_t} = \\frac{\\partial L_{t}}{\\partial h_t}$\n",
    "* All other derivatives are of $L$\n",
    "* `target` is target character index $y_t$\n",
    "* `dh_next` is $dh'_{t}$ (size H x 1)\n",
    "* `dC_next` is $dC'_{t}$ (size H x 1)\n",
    "* `C_prev` is $C_{t-1}$ (size H x 1)\n",
    "* $df'_t$, $di'_t$, $d\\bar{C}'_t$, and $do'_t$ are *also* assigned to `df`, `di`, `dC_bar`, and `do` in the **code**.\n",
    "* *Returns* $dh_t$ and $dC_t$\n",
    "\n",
    "#### Model parameter gradients\n",
    "\n",
    "\\begin{align}\n",
    "dW_v &= dv_t \\cdot h_t^T \\\\\n",
    "db_v &= dv_t \\\\\n",
    "\\\\\n",
    "dW_f &= df'_t \\cdot z^T \\\\\n",
    "db_f &= df'_t \\\\\n",
    "\\\\\n",
    "dW_i &= di'_t \\cdot z^T \\\\\n",
    "db_i &= di'_t \\\\\n",
    "\\\\\n",
    "dW_C &= d\\bar{C}'_t \\cdot z^T \\\\\n",
    "db_C &= d\\bar{C}'_t \\\\\n",
    "\\\\\n",
    "dW_o &= do'_t \\cdot z^T \\\\\n",
    "db_o &= do'_t \\\\\n",
    "\\\\\n",
    "\\end{align}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def backward(target, dh_next, dC_next, C_prev,\n",
    "             z, f, i, C_bar, C, o, h, v, y,\n",
    "             p = parameters):\n",
    "    \n",
    "    assert z.shape == (X_size + H_size, 1)\n",
    "    assert v.shape == (X_size, 1)\n",
    "    assert y.shape == (X_size, 1)\n",
    "    \n",
    "    for param in [dh_next, dC_next, C_prev, f, i, C_bar, C, o, h]:\n",
    "        assert param.shape == (H_size, 1)\n",
    "        \n",
    "    dv = np.copy(y)\n",
    "    dv[target] -= 1\n",
    "\n",
    "    p.W_v.d += np.dot(dv, h.T)\n",
    "    p.b_v.d += dv\n",
    "\n",
    "    dh = np.dot(p.W_v.v.T, dv)        \n",
    "    dh += dh_next\n",
    "    do = dh * tanh(C)\n",
    "    do = dsigmoid(o) * do\n",
    "    p.W_o.d += np.dot(do, z.T)\n",
    "    p.b_o.d += do\n",
    "\n",
    "    dC = np.copy(dC_next)\n",
    "    dC += dh * o * dtanh(tanh(C))\n",
    "    dC_bar = dC * i\n",
    "    dC_bar = dtanh(C_bar) * dC_bar\n",
    "    p.W_C.d += np.dot(dC_bar, z.T)\n",
    "    p.b_C.d += dC_bar\n",
    "\n",
    "    di = dC * C_bar\n",
    "    di = dsigmoid(i) * di\n",
    "    p.W_i.d += np.dot(di, z.T)\n",
    "    p.b_i.d += di\n",
    "\n",
    "    df = dC * C_prev\n",
    "    df = dsigmoid(f) * df\n",
    "    p.W_f.d += np.dot(df, z.T)\n",
    "    p.b_f.d += df\n",
    "\n",
    "    dz = (np.dot(p.W_f.v.T, df)\n",
    "         + np.dot(p.W_i.v.T, di)\n",
    "         + np.dot(p.W_C.v.T, dC_bar)\n",
    "         + np.dot(p.W_o.v.T, do))\n",
    "    dh_prev = dz[:H_size, :]\n",
    "    dC_prev = f * dC\n",
    "    \n",
    "    return dh_prev, dC_prev"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Forward Backward Pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Clear gradients before each backward pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def clear_gradients(params = parameters):\n",
    "    for p in params.all():\n",
    "        p.d.fill(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Clip gradients to mitigate exploding gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def clip_gradients(params = parameters):\n",
    "    for p in params.all():\n",
    "        np.clip(p.d, -1, 1, out=p.d)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calculate and store the values in forward pass. Accumulate gradients in backward pass and clip gradients to avoid exploding gradients.\n",
    "\n",
    "* `input`, `target` are list of integers, with character indexes.\n",
    "* `h_prev` is the array of initial `h` at $h_{-1}$ (size H x 1)\n",
    "* `C_prev` is the array of initial `C` at $C_{-1}$ (size H x 1)\n",
    "* *Returns* loss, final $h_T$ and $C_T$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward_backward(inputs, targets, h_prev, C_prev):\n",
    "    global parameters\n",
    "    \n",
    "    # To store the values for each time step\n",
    "    x_s, z_s, f_s, i_s,  = {}, {}, {}, {}\n",
    "    C_bar_s, C_s, o_s, h_s = {}, {}, {}, {}\n",
    "    v_s, y_s =  {}, {}\n",
    "    \n",
    "    # Values at t - 1\n",
    "    h_s[-1] = np.copy(h_prev)\n",
    "    C_s[-1] = np.copy(C_prev)\n",
    "    \n",
    "    loss = 0\n",
    "    # Loop through time steps\n",
    "    assert len(inputs) == T_steps\n",
    "    for t in range(len(inputs)):\n",
    "        x_s[t] = np.zeros((X_size, 1))\n",
    "        x_s[t][inputs[t]] = 1 # Input character\n",
    "        \n",
    "        (z_s[t], f_s[t], i_s[t],\n",
    "        C_bar_s[t], C_s[t], o_s[t], h_s[t],\n",
    "        v_s[t], y_s[t]) = \\\n",
    "            forward(x_s[t], h_s[t - 1], C_s[t - 1]) # Forward pass\n",
    "            \n",
    "        loss += -np.log(y_s[t][targets[t], 0]) # Loss at time t\n",
    "        \n",
    "    clear_gradients()\n",
    "\n",
    "    dh_next = np.zeros_like(h_s[0]) #dh from the next character\n",
    "    dC_next = np.zeros_like(C_s[0]) #dC from the next character\n",
    "\n",
    "    for t in reversed(range(len(inputs))):\n",
    "        # Backward pass\n",
    "        dh_next, dC_next = \\\n",
    "            backward(target = targets[t], dh_next = dh_next,\n",
    "                     dC_next = dC_next, C_prev = C_s[t-1],\n",
    "                     z = z_s[t], f = f_s[t], i = i_s[t], C_bar = C_bar_s[t],\n",
    "                     C = C_s[t], o = o_s[t], h = h_s[t], v = v_s[t],\n",
    "                     y = y_s[t])\n",
    "\n",
    "    clip_gradients()\n",
    "        \n",
    "    return loss, h_s[len(inputs) - 1], C_s[len(inputs) - 1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sample the next character"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(h_prev, C_prev, first_char_idx, sentence_length):\n",
    "    x = np.zeros((X_size, 1))\n",
    "    x[first_char_idx] = 1\n",
    "\n",
    "    h = h_prev\n",
    "    C = C_prev\n",
    "\n",
    "    indexes = []\n",
    "    \n",
    "    for t in range(sentence_length):\n",
    "        _, _, _, _, C, _, h, _, p = forward(x, h, C)\n",
    "        idx = np.random.choice(range(X_size), p=p.ravel())\n",
    "        x = np.zeros((X_size, 1))\n",
    "        x[idx] = 1\n",
    "        indexes.append(idx)\n",
    "\n",
    "    return indexes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training (Adagrad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Update the graph and display a sample output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_status(inputs, h_prev, C_prev):\n",
    "    #initialized later\n",
    "    global plot_iter, plot_loss\n",
    "    global smooth_loss\n",
    "    \n",
    "    # Get predictions for 200 letters with current model\n",
    "\n",
    "    sample_idx = sample(h_prev, C_prev, inputs[0], 200)\n",
    "    txt = ''.join(idx_to_char[idx] for idx in sample_idx)\n",
    "\n",
    "    # Clear and plot\n",
    "    plt.plot(plot_iter, plot_loss)\n",
    "    display.clear_output(wait=True)\n",
    "    plt.show()\n",
    "\n",
    "    #Print prediction and loss\n",
    "    print(\"----\\n %s \\n----\" % (txt, ))\n",
    "    print(\"iter %d, loss %f\" % (iteration, smooth_loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Update parameters\n",
    "\n",
    "\\begin{align}\n",
    "\\theta_i &= \\theta_i - \\eta\\frac{d\\theta_i}{\\sqrt{\\sum_{\\tau} d\\theta_{i,\\tau}^2}} \\\\\n",
    "d\\theta_i &= \\frac{\\partial L}{\\partial \\theta_i}\n",
    "\\end{align}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_paramters(params = parameters):\n",
    "    for p in params.all():\n",
    "        p.m += p.d * p.d # Accumulate sum of squared gradients\n",
    "        #print(learning_rate * dparam)\n",
    "        p.v += -(learning_rate * p.d / np.sqrt(p.m + 1e-8))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To delay the keyboard interrupt to prevent the training \n",
    "from stopping in the middle of an iteration "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import signal\n",
    "\n",
    "class DelayedKeyboardInterrupt(object):\n",
    "    def __enter__(self):\n",
    "        self.signal_received = False\n",
    "        self.old_handler = signal.signal(signal.SIGINT, self.handler)\n",
    "\n",
    "    def handler(self, sig, frame):\n",
    "        self.signal_received = (sig, frame)\n",
    "        print('SIGINT received. Delaying KeyboardInterrupt.')\n",
    "\n",
    "    def __exit__(self, type, value, traceback):\n",
    "        signal.signal(signal.SIGINT, self.old_handler)\n",
    "        if self.signal_received:\n",
    "            self.old_handler(*self.signal_received)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exponential average of loss\n",
    "# Initialize to the loss of a random (uniform-prediction) model\n",
    "smooth_loss = -np.log(1.0 / X_size) * T_steps\n",
    "\n",
    "iteration, pointer = 0, 0\n",
    "\n",
    "# For the graph\n",
    "plot_iter = np.zeros((0))\n",
    "plot_loss = np.zeros((0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXEAAAD1CAYAAACm0cXeAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJztnXecVNX5/98zW9llWfrSWaUcKyAdVMTuzxpLNErsmmiiiSUS800s0URjiJIETUyssSvRhCh21IAUaQoieihSpC8LLNtnZ2d+f9y5d+7M3Cm7zMzuNc/79eLF7C0z587c+znPeZ7nPMcTDAYRBEEQ3Im3rRsgCIIgtB4RcUEQBBcjIi4IguBiRMQFQRBcjIi4IAiCi8nN5ocppQqAMcB2oDmbny0IguBScoDewBKtdWP0zqyKOIaAz8vyZwqCIHwbOBb4OHpjtkV8O8Dzzz9Pr169svzRgiAI7mPHjh1MmTIFQvoZTbZFvBmgV69e9OvXL8sfLQiC4GocXdAS2BQEQXAxIuKCIAguRkRcEATBxYiIC4IguBgRcUEQBBcjIi4IguBiRMQFR/bU+ii/fTZvr9rR1k0RBCEBIuKCI+sragB4bN7XbdwSQRASISIuOFKQa9wajX4pcSMI7RkRccGRgtwcABqbAm3cEkEQEiEiLjiSl+MBwNcsIi4I7RkRccERr8cQcbHEBaF9IyIuOBIM/S8+cUFo34iIC44Eg4aMN/rFEheE9oyIuOCIaYn7RMQFoV0jIi4kxB8IJj9IEIQ2Q0RccCQo2i0IrkBEXIiDqLgguIGky7MppYqAp4EyoBi4BzgPGAVUhg6bprWerZQ6F5gKFAIztNZPZqLRgiAIgkEqa2yeDSzVWv9eKTUQeA9YAPxCa/2GeZBSqgSYhiHuTcBypdQrWuuaDLRbyDDiThEEd5BUxLXWL9n+7AdsiXPoGAyxrwJQSs0HjgXeOtBGCtlHNFwQ3EHKq90rpT4BegGnA7cBNyqlbgN2AD8GegMVtlN2hY4XXIhY4oLgDlIObGqtxwHnAi8BzwK/1FofByzB8JP7ok7xIAadIAhCRkkq4kqp0UqpAQBa6+Whcz7XWi8NHTIbOBzYDvS0ndoL2Jbe5grZIij9ryC4glQs8YnATQBKqTKgBHhYKTUstP8YYBWwGBimlCpVSnUExgHz0t9kIRuIO0UQ3EEqPvG/AU8ppeYBBcCPgP3A40qpOqAauEpr7VNK3QnMBQLAPVrr+gy1W8gwIuKC4A5SyU5pBC5x2DXW4diZwMw0tEsQBEFIAZmxKTgiPnFBcAci4oIj4k4RBHcgIi4IguBiRMQFQRBcjIi44Ii4UwTBHYiIC45IYFMQ3IGIuCAIgosRERccEXeKILgDEXHBEdFwQXAHIuKCI0ExxQXBFYiIC4IguBgRccERscMFwR2IiAuOiDdFENyBiLgQh7CKi39cENovIuJCUkTDBaH9IiIuOGIXbtFwQWi/iIgLjtiFOyCmuCC0W0TEhaSIiAtC+0VEXHAkwp0iGi4I7RYRccERe0aKiLggtF9ExAVHxCcuCO5ARFxIiki4ILRfcpMdoJQqAp4GyoBi4B5gEfAM0BnYAkzRWjcqpc4FpgKFwAyt9ZMZareQYezGt1jigtB+ScUSPxtYqrU+Djgf+AMwDXhKaz0e2AhMUUqVhLafBhwNTFVKdcxIq4WMY1/ZJxhow4YIgpCQpJa41vol25/9MCzvycB1oW2zgBswxHyp1roKQCk1HzgWeCt9zRWyRsRkH7HEBaG9klTETZRSnwC9gNOBeVrr+tCuXaHtvYEK2ynmdsHlBETDBaHdknJgU2s9DjgXeAnw23Z5MOw2X9Qp5nbBhUh2iiC4g6QirpQarZQaAKC1Xh46pzYU8ATD2t4GbAd62k41twsuRCb7CII7SMUSnwjcBKCUKgNKgDeAc0L7zwNmA4uBYUqp0lBAcxww
L+0tFrJCUErRCoIrSEXE/wb0UkrNA14HfgTcB/xQKbUE6Aq8rLX2AXcCczHE+x6b31xwMeITF4T2SyrZKY3AJQ67JjscOxOYeeDNEtqaoGSnCIIrkBmbgiORgc02a4YgCEkQEReSEhAVF4R2i4i44IgEMwXBHYiIC45InrgguAMRccEZyRMXBFcgIi4kRSxxQWi/iIgLjtjTCiWuKQjtFxFxwZFI41tUXBDaKyLigiORi0K0XTsEQUiMiLiQFPGJC0L7RURccMQu26LhgtB+EREXHLFP9hFLXBDaLyLiQlJEwwWh/SIiLjgi7hRBcAci4oIjkdkpouKC0F4RERfiID5xQXADIuJCUkTCBaH9IiIuOBK5ULLIuCC0V0TEBUdkZR9BcAci4oIjQSlFKwiuQERcSIoENgWh/ZJ0tXsApdR9wPFAHvAAcAYwCqgMHTJNaz1bKXUuMBUoBGZorZ9Mf5OFbBCU7BRBcAVJRVwpNQkYobWeoJTqCqwE3gd+obV+w3ZcCTANQ9ybgOVKqVe01jWZabqQSSJ0WzRcENotqbhTFgAXhl5XAflxzhsDLNVaV2mt64D5wLFpaaWQdSSwKQjuIKklrrX2A6Y1fTXwZuj1jUqp24AdwI+B3kCF7dRdQK/0NVVoK4JiigtCuyXlwKZS6hzgWuAm4Fngl1rr44AlwD2AL+oUDzIQdy2RVQyTH//Exxsov302jf7mDLZKEIRoUg1sngrcCZyitd4HzLHtng08CrwA9LRt7wV8kKZ2Cm1IKoHNhz9YC0BNg5+CjjmZbpIgCCGSWuJKqVLgIeB0rXVlaNsrSqlhoUOOAVYBi4FhSqlSpVRHYBwwLzPNFrJKCpa4eYjH48loUwRBiCQVS/wioAvwslLK3HYn8LhSqg6oBq7SWvuUUncCc4EAcI/Wuj4DbRayQGurGIqEC0J2SSWw+Xfg7w67xjocOxOYmYZ2CW1MZJ54CsdL9EMQ2gSZsSk4IgWwBMEdiIgLSUnNEjcOErkXhOwiIi440lJL3DxCrHZByC4i4oIjLZ51HzpIZncKQnYRERcciZzsk7oyy+xOQcguIuJCUlLyiZv/i4YLQlYRERcciXCnpOITNwObIuKCkFVExAVnWriyj3mI1B4XhOwiIi4kpWU+cUEQsomIuOCIPUCZkiVuZqdIeoogZBURccGRltZOkawUQWgbRMQFRyIDm6mfJz5xQcguIuJCUvwpuEiCMtlHENoEEXHBEbtB7WvBaj0y7V4QsouIuOCI3cftaw6kcLyBWOKCkF1ExAVH7AZ1Y1NyEQ9rvqi4IGQTEXEhKalZ4oZ4iyUuCNlFRFxwxK7Fjf4ULHHzPBFxQcgqIuKCMzY19qUg4uHsFFFxQcgmIuKCI6YU53g9NKaQnSJVDAWhbRARFxJSkOtNKbBpphaKJS4I2UVEXHDE1OKCXC+NKQQ2BUFoG3JTOUgpdR9wPJAHPAD8F3gG6AxsAaZorRuVUucCU4FCYIbW+smMtFrIOKZlXZiXk5olHvpfLHFByC5JLXGl1CRghNZ6AnAKMB2YBjyltR4PbASmKKVKQttPA44GpiqlOmaq4UJ2KMj1ppZiGIz8XxCE7JCKO2UBcGHodRWQD5wA/Ce0bRZwKjAGWKq1rtJa1wHzgWPT21whW5haXJCbQ2NT6tPuxRIXhOyS1J2itfYDNaE/rwbeBM7WWteHtu0CegG9gQrbqeZ2wYVYPvG81CxxE5nsIwjZJeXAplLqHOBa4CbAZ9vlwTDcfFGnmNsFF2L+cIW5sT7xpuYA+xuakpwpCEI2SEnElVKnAncCp2mt9wHVSqmi0O5ewDZgO9DTdpq5XXAxBXnemDzxn7z4KcPuftfxeLHEBSG7pBLYLAUeAk7XWleGNr8NnBN6fR4wG1gMDFNKlYYCmuOAeelvspANzOwUp8DmW6t2RBwTeV7m2yYIQphUUgwvAroALyulzG2XA/9QSt0CaOBlrbVfKXUnMBcIAPfY/OaCS8nP
9caddu8PBMnL8URsk8CmIGSXVAKbfwf+7rBrssOxM4GZB94soa0xtTjX6yVeXNPnD5CXEzmYEw0XhOwiMzaFhOR4PXGtaycLXVb2EYTsIiIuOGLWB8/xemiOilaaLhSn1MMg8OLizZTfPjtBBosgCOlCRFxwxDSoczweAlEinus1bhsnSzwQDPLExxsA2FHVkNlGCoIgIi4kxuvgTsn1JrDEg2GXiidmryAI6UZEXHDElO1cr4daXzPlt89m5tJvAMgx3SlxLHFBELKHiLjgiOVO8Ybt6cfnGW6SRO4Uu4R7xBQXhIwjIi44YgY2c20ibopywsBmMCgT7wUhi4iICwmxW+Lm69xE7pQANnNcTHFByDQi4oIjpjvFaxNxb8gUT9WdIghC5hERFxJid6eYgm5a5I1xApvhRZNF0gUh04iIC46YApwTYYkb/5vC3hQnxdCkWURcEDKOiLiQkBxbion52qyXEm/avdkBBGR9ZUHIOCLigiNWimFOrDslN8m0e1k0WRCyh4i44Igpv3ZL3HSn5CWZdm8SXXNFEIT0IyIuOOI02cfMTjG3ObtTwueKJS4ImUdEXEhIwjxxB3eKXbhFxAUh84iIC444zdg0LXHzf39zYpGOt5iEIAjpQ0RccCTsTgnfIqaem25yv0P6iZEnHrReC4KQWUTEhYTYV18zLXAzXtnkYIlH+MQlsCkIGUdEXHDElF+PJzbF0MwD9zv6xO2BzYw2URAERMSFeJgzNh0m+5ipg34HlY5IMRR3iiBkHBFxwZEghu/bKTvFFGqnaff2CljiThGEzJObykFKqSOAWcB0rfXDSqmngVFAZeiQaVrr2Uqpc4GpQCEwQ2v9ZAbaLGQJD5FVDE2j3NRmp+wUSTEUhOySVMSVUsXADGBO1K5faK3fsB1XAkzDEPcmYLlS6hWtdU0a2ytkCftCybH74rtTgrb9zYEg976xmiP6duLco/plrK2C8L9MKu6URuB0YFuS48YAS7XWVVrrOmA+cOwBtk9oI4IE8Xg82Axxy7K2LPFQimEwyvq210554uMN3Pzyimw0WRD+J0lqiWut/YBfKRW960al1G3ADuDHQG+gwrZ/F9ArTe0UskwwGOtOMQOaASs7JdYit3tQxCUuCJmntYHNZ4Ffaq2PA5YA9wC+qGM8yEIvrsbjiXSnWCIeiAxs2n3jQSmAJQhZJaXAZjRaa7t/fDbwKPAC0NO2vRfwQeubJrQlVhVDR0vc+Nu0wJtsMzcNnzih40TEBSHTtErElVKvAL/RWq8EjgFWAYuBYUqpUqAZGAdcn66GCtnFcKd4Itwptb5mym+fbf3tZIkHAjLtXhCySSrZKaOAB4FyoEkpdQFwJ/C4UqoOqAau0lr7lFJ3AnOBAHCP1ro+Yy0XMk+UO2VzZV3EbmvST3OkJR7en9HWCYJAaoHNZcBkh11jHY6dCcw88GYJbY1pTduzU/bWRYY9TAu8KWDPThF3iiBkE5mxKTjjkJ0Svbq96QuPsMSDka4VQRAyi4i44IjTtPtoLEs8Ijsl7FKR2imCkHlaFdgU/jfw4LHKzzphBTZt2Sm/ffNL67WkGApC5nGlJX7V00t4dtGmtm7GtxrTLZLQErcCm85ibV+D85s9ddz00qfsb2hKYysFQXCliH/w1S7u+Peqtm7Gt5pg0HCnJNBwyxfuWM2QyDU431q1nX9/to373/yqxW1pag5w28wVMdkxgiC4VMSFzBMkFNhM4E7xJ6grDtDkD2/vXJQPwKqtVS1uy+INe5i5bAs/f3Vli88VhG87IuJCXDweT4qBTWdL3L7dPLY1aYemb90rd6sgxCCPheBIeKHkRJZ47IxNOxEiHjq2NcFOU/gTjQpawvaqem6buYKGpua0vJ8gtCWuE/GgpK1lhSDBpO6UJquKYXKfuHlsa36+QApB1pbw4LtrmLlsC7NXbjfePxDkv2sq5N4SXInrRNxphXUh/QRDTvHUApvJLXHzdbAVhS3Nt3FaoKI1
dO6QB8DumkYAnl6wkcufXMzbq3ak5f0FIZu4TsQl9zh7JJPMpiQphvbApin4B+ROSZMl3qOkAICKakPEN1XWArBzf0Na3l8QsonrRLwpztBdyAyJRNfvMNnHTpODO6U1fbDZSaRJw+lYaMxxqwhZ4mab0tVJ2PE3B9rMTbOruoG5ayqSHyi4GteJeDyrT0gvwaCxPFu89EEwxM/fHIg/2cchsOmLqr/yh3c0f3p/bcK2NPqNAGS6fOJmx1RZYxT0Mi399Es4DP7lW0z9Z+LUyOqGJh75cF3aR5nffXQhlz25WHz933JcKOLusMS37qunMmTpuRWPBzoWJK7M4GsOpGSJm0JvCvuiryspv302D3+4junvr0n4GWbhLU+UTzy6Q0iV6LaY2hn9/gdCdUOTda/OXLYl4bH3v/UV097RvLc6vT75TaHJUeKC/HbjOhFvcskNefTvPmDUb95v62a0GvNb7t+1iFevn8APJx0csd80ipv8wQSBzfD2+lA6X2PofzMzJBXMc+yBzUZ/M0N/9Rb322q1pEr0MnPBNKcwBoNBjrz7XW5LYoGb7K83ShFEV4k8EL7YFp5UJckA325cJ+LNckNmBXOhZIBRA7tS3r04Yn9hXg4Ajc3NcUdHdku83mcIsWn9tqTCoSludndKXaPxfn+b+3XK72O1KxDZhnAeeovfyhGzk/jXp1vT84ZxqPc1c9aMj/nsm30x+86c8bH12ueS0avQOlwn4hLYzA5BghHuhV6dCiP2myLu8wfi+s3tlqXpUmhoCrB0454W+WnD7pTwtgMRpuaoIGvQcqe0+i0jSBRHSCert+/n861V3PHvVdz88md8tWO/tc/+9cabUdteqahu5LezV7vGddrWuE7EJbCZPeya1qU4P2JfYa5x6/j8gZTyxO1c8OjCFvlpzcCmXZjs/vCWLj7hj3KnpNsn3tIOprV3dEHoN9A7qvnXp1u5/rnljselW8TX7qzm8y0tr4GTKnf9ZxWPzdvAvLW7M/YZ3yZcJ+JusyrcSrShfFAcd4qvOUC9z+9oxSb6rVryMzY2xVZLtAvlEx9vSP3NCLs7zI7AHBWkK4sjW4aG6QYyvwvTtx6NPV8/HZw8fS5nPfxx8gNbifm7pDqiqapvsiZu/S/iOhGXSHt2MFf2MSntkMeG+0/n7OF9gEh3yt66JkpDsyDtJBKP1rhTTBEPBIIRgvXbFgY3TZdcODsldoWiaPzNgZSzYeJ1Xh98tZNhd79Dnc/vuL+lI4Ho9thrtefaHPzZ9Ik/u3Aj5bfPblHgOhaj7aneI+Pue5/RLk4iOFBcI+LrdlVz5ox5zFz2TVs35X8C4/mJFBWPx0NujrGtMM+4dVZsqeLZRZusqex2ahqdxQpaVs3QdKeYFu6vZq3i3L8siDgmmf/0rc+3s3qb4TM2feKmCJp2QSID4ZxH5jP0V2+l1N54Iv67t75if4OfzXui6qJbn98ysY1Z89TWCdmDwNkavX65fT93zPoCgN+/E64b/+vXv6D89tkpv4/Zly3bvJfy22ezoyrxTNqGprYZnX+1Yz/lt8+27qu2IiURV0odoZRar5S6IfR3T6XU20qpRUqpfyqlCkLbz1VKLVRKfaqUuiqdDb115kpWbd3P26t2WttkId7M4mQYmml+piVuLs7hZEVu3Vcf973j/XSbK+tYt6s6Ytv+eqMzMNNLZzlkfSRLz7v++eWc/ud5QHiYHmuJx3+PL1rwoLY2pa+lbg+zc7OzbNNeINISz5aI19uqQtbaOvCn5m8EYjvJHz2/jP+s2BbzPmbL/7HAOG/B+sz4xjfurrU+IxFvr9pB+e2zrfIM9u0Ab3/RtjV3koq4UqoYmAHMsW2eBjyltR4PbASmKKVKQttPA44GpiqlOqarodcccxBF+TnsrfNZ29yaqVLn8ye0UtsHzoLijRJxk+oGP4f0KrH+TjZJKDrFcHuVIfiTpn3ISQ/Ntba/tnyL9ZA0hYR65MAuMe8XLeIL11ey
y6EWyoL1u2N94qF9qfhg1+6sTnpMvFGB0+DjtD/OZfbnhuuhpW4PJ/fOko17gOSWeDAYTFqKd1NlLdsSTFqLdnfYu/Hqhtj7O7rTefPzHfzkxU9jjovO18/UhNOLH1vEXf/5wkp/jce/Q0ZDSzrybJKKJd4InA7Yu8zJwH9Cr2cBpwJjgKVa6yqtdR0wHzg2XQ09a3gfJg7qFtGbu3USw1H3vMcRd73T1s1IiD1P3I5ZX8R0p5jUNDbxzNVjrb+7FMe6V0wGdC2KGUVNuP8DPtS7IrY1+pu55ZUV1t/mzNDoDsQ81s7Fjy1i0rQPgcgR2yWPfRJTAsAUo1RS2k6ePjfpMfHE2GyF3eL+ake4U2jpDFSn0Ud1yC+elxP+fXwOFv4TH2/gkDveZm+tL2afyXHTPmJiaNLa4g17YvZHX6f9Uxr9AZZs3BPh/69u8FsdayLXlanhTQewkEgqmAXQnEY0dszKm4kiFm1plCUVca21X2sdPS4usW3bBfQCegP2ajvm9rRhVp8zaUrjDLd0kiwgk86ZeZnEyZ2SH/KJR/shG5oC9CwptKzx6LxykwtG9cPfHHB8MNfYBC0YDLJ+V+Tw1WfVJI89194ec7+5rSHqIY2edm8KSrryu5Nlp8QTjZa6PRodfMGmBRxtib+4eHPELM4XPtkMwK7q1LI69I5YK7SusZlgMMg/Fmykqq4pRpi/++hCfvrSZ9bfN774KWPvm5M0SGzed83WjNrYY95bvTMmeBrvudtR1eDYaZi/d32SEUk4BdV5/7pd1Rxx1zu8tjxxeYVM0drApr379mB0wtFdurk9bXTvGCni7dUlka3JHpkkXj90waj+QGzK4YPfHR5x3s9OUY7n5+d68TUHHB8qs7ogwP4GP/vqI28p01J26gTtwmjf//aqHdQ2Rj6k5mc3B4I0B4KW6NrFd31FDeW3z44QvlRxEmO7RRo9Akh0XiIao44vzs+xRNzuE7/91ZX84rXPOePPH4drwIdGIzWNzmmJ0dh/G5NV26o46Bdvctd/vuDaZ5c6dirz1obtOtOar27wJ3TleKJsXrPDn7e2goXrKwG49pml/PiFyLx4pxHQnlof4++fw+/eip/BlMydEn7/+J0EwH2tKAGRDlor4tVKqaLQ614YrpbtQE/bMeb2tBFtiVcmGAq2Jd+GNEhjZZ9Y0+PIfqUs/dVJXHXMQda2eVOP5/xR/YDwA2cujBxNfo6XxqaAY0aB3e1RWdNopRGOO6grw/t3tkTWSSzs2+wCMXdtRcxDau9kqxuamPOV4caxx1j+GZph+v7qSBePncqaRtY4+Mid3Hzff/wTS7TNTiZadExxf2zu1zw+L3k5gcYoIezZqTBsieeEf7tttuyOPaGYkunS2VeXmog71ZX5zRth0Vq8YY/jCMPpd95X3xQzOrIT/VFm/OTSJxZz8WOLIvfZfkunz9pTa4w0zN/YiWSWuNnX1sdJDTV/7t01vjZ59lsr4m8D54RenwfMBhYDw5RSpaGA5jhg3oE3MUyPKEu8IsWhYLb5NkxICgbjDx+7dyywZmwC9C4Nu07CS6k5BzcLcr00NgccLbEam8W8p9ZnCcz0i0bQv0sH63t1FotmXlu+hbMf/pgbXggHy/bXN1Eb9fDZqy7+3VZ7xV6XZ8tew1vYu3OsW8j0I1/59BJOmT435vd2+v2Xbw7XNzFFvC5qhGBaer9980t+M/tL9iQxUqI7gR4dCyyfeG6cVaX31jZFtNH8jp+ev4EL/rog7ui2LtQR2jvabVHZR6n69PfV+RKmBUZnOjl12iZ2AY7u1CDcoSYqbhZ9L+7c38Cf3l8bUyAtekRnGhV1tu/sy+3ZD36mkp0ySin1EXAF8NPQ698BP1RKLQG6Ai9rrX3AncBcDPG+x8GXfkB0j7LE2+ssrW+DJQ6JAzn5NhHPtQXR7FPY44m4zx+wFmSwY09Lq6z1WcPvzkV55OV4LUvZ2Z0S4JZXVrBySxUfrwunpFXW+GKC
cm9+Hk4Js99Ddgt9y14jl9vp4T/q3veAcFBy5ZbIAlTxSvOG22qIQbRgLly/m//aFnEw0wVNot0v0eLWqUOuZYl74iytd+of5/Kn99daHcC++iaCwSB3v76apZv2xqTRmdT5mvm6oibiu6+Oan+qsZ599U1x3SkT7p/D61Fph9FWu/17sAuoU8dgurESFTer90WeN/WfK5n+/hpWhH5X89OiLXbzGvbZJp59Ez0HIAskzgMDtNbLMLJRoonZprWeCcw84FbFIdoS/8VrnzN/3W4evmQkW/bWUZSfS9di52G8E+98sYNv9tRxzbEHJz+4BdjFwFxcwW0k64bsIm7HvjJ9x8JcsBkmL/1gPEtDKXBmrWs7dlFbtmkvr4VSuzrk5VCY57UeNieL77InFzu2Z+HXlSz8ujLudWzcHW7Hf9dUMH/dbo4e3N1y5cSzLu2TVz77popRA7tafztlg0D4O230GzGB5ZsjRXrFliout13Htc8s5Z/XTWB0eVe2V9Vz3LSPmPnDCQzv39l6HzslhXlUNxodS1NzgHEHdXO89unvryEv5G6pqvNZow4w/NWBQJB1FTUR5yzeUMm9b6zmvJF9Ha8N4O7Xv4i7z05VXZPj5DCA7Q4Texqi3GHmvAGAvTZ3kNMIzRzdJbLEo8XZ/NuqvBn6nmujOi3zuCpbG6I7tkAgyPT31/D98QMpixPsP1BcM2MTYn3iAG+s3I6/OcAxD3zIcb//kGAwmHIw6ofPLuM3s9MfjLAHyNxqlRvulPg3fn6OlwtH9+OFa8dFbDc7xJ4lBZx8WBkA108exLNXj2X8wd3iij8YD4lpMdndHB6Ph9IO+VTV+wgGgylbfEX5samI0azZVc3EQd0o61TAht21THn8E6obmqwRRbL0MyAmHz2uJR56T58/wPT31kRkbsRjcajT27K3Hp8/EDFcj+5gSjvksa/WsKyb/EGKC+Jfv+lm2F3ri3jP6gY/c9dWcEpUKqWZI/3a8vjldZ3861fbYicmVfVNLZplWd8UWe54V3X4+zbnFxivG2JSV03hTXQvR4u4WVis1hc5YqrzNXPR3xby4+eXR5xnd2vVROXHr9xaxYwP1nGrLVU23SS1xNsTxbbhedfifMtnOPiXxnTo6kbXY56BAAAZhklEQVQ/zy3axB2zvuCFa8cxcVD3Nmmn/SFuag6Sm1xL2h3JVqX3eDz8/oLhMdsvHT+QS8cPBOC2UxQXjxnAgG5F1v78kOvlyL6lfL41srOt9fnJz/U6PuBdivJoag5S62tOSVjB6Eg2Olj8dvbVNdGvSwcWrA9brEfe/a71utGffI3M6NhMsphIoz/A3LWprX1pfl+mkFRUN1Ln83PHv7+IcUkN6FpEdaOfylofTc0BivKTP94L1u2mp804euDtr+jbuUPMcXZrvSUc0bdTzLY9tT7HwGa8GdgNTQFLUAG22HzxK23VFC97cjHHDunOs1cbhsXiDXus5AcPhkGV4/VQ5/Pz5fZwQLrB18wd/17FcUN7cNJhZZaIm7EP87t/ZuFGq3N/hFhfOsRa6ybRmVbpxFWWOMDTV47ho59NpsQh5QmwhMFpuB6PdNcttlviiWaVBoNBbn75M87OYEW4A+FAvUBerydCwCHsAhh3UNeY46sb/DEjlzduPAaALqFsl/P+Mt8KsgGMdpi9adKzJP7w9fA+nawJS/Fy2sGwdqN909FE51onm4Tm8wfiPuzR7K9vIhAI8kxoenhFTSMvLv6GV5dviVgEuVNhLoN7GhOk1+2qwdccSGiJA0wc1I2NlXX88f211vO0bldNhF8eEt8HV0ws53tj+sdsP2t4H2b9+GjHLKVNlbURQciGpmYe/mAt767eGXMsGBav3dW2cXfYb//Qe5FL+81bu5tgMMiandVc+LeFVlmI1dv3M+j/3gTg/L8u5Py/hmvv1DT6eXbRJq55ZikQdhVW1vpoaGq2Ps9+ay5cXxmT9eT1wOY9dZw6fa5VOsI8Jt2VJCM+
N2PvnCEmq56Udy+2estoWqPH9inCNY1+K6jVWuw+8egJSXarzh8I8q9Pt0ZYE9lgw+7auAEsiwzdc2alvU4OPtGK6sYYATTrmJcWGcev2VkTIeJPXD4m7mf16BTrfjPpXJRniV6v0ljL0+Sh99ZwwaML4+6HyOE9xLfEvw6JwesrtrG+Ivz9333WYXHfe3etj9dXbuNDbQhrRXVjTCAVjO/JvJ4vtu2nusGf1BK/+eSh1uux5bGdKsCMi4+iX5f430+34nyOGtA5ZvsfLxrB8P6dHX3f6ypq+MQWbP7ZzBX84d01XPfcMsfPaGhqjnBTbNid+N7dub+Rr6N8+tZn76qJySBZZXO/bthdi2l37a3zsXJLFf5AkCP7lkacc/FjiyJGB16PYWj8c/kW9M5qHpu7AQhb8ZmsJOk6ETc5/pCejtsrQ3mhTc0BNlXWcsr0/7LToYaG3Y95woMfMXPpN8xeuZ0rn1rMMQ98yPLNe1m7s5qXl2y2jpv2zld8utnZKvtqx37rPe3ulOiJP/Yfs7UL/R4ox//hI46b9lHc/RXVjWysrE3bSjd2zA6zpDCX+849kpEhAeiQlxMxBd3EzHDpEifvvCAv/i2cSHwAuhYbIt+5KH6JgFRYs7OGRz5cx6zPtnLpE5+wfV/iqnuro0Tk++MH8t7NkyK2/eTEIeTneNm+rz7i/n1r1Q5mfRbO3jBdH5eMHUDv0kIGdivi3jdWA8bU+x8e5xy093hA9SrhxNBzZAZLoynI9VpBwbOH9+GG4wdz+YSBvPajiQAcM6S7YxkEc8aokyW+bleNVRSrpDCXN5KUrf1kw56IjjKZi2x9RQ079ztnrt05a1XMNruf/7Q/zrVSUjfuruXPc9ZSUpDLgxfGug7NID0YVnrHwlwrp7xzcR7/+nSL9XmZfNZd5RO3c9spinU7a5jz1S7GlnelptHP6u37+ShksTw1fyOrt+1nzc4axt03h4cuHM55I/sRDAZ554udEYV39tY1xSxqe95fFpDr9eAPBLlwdH+amoM88uF6HvlwPf+8bgKPzfuav0wZRY7XQzAY5LQ/hlPiH79stPX6+ueW8fRVY+lUaAiF/cfMZD75/oYm8nO8MQ9YKilQ4++fQ3MgyKAexUmPbSljyrvyzMJNjBrYhWH9OnPh6H5WFspUh4WFTRGPDlKefmQvpp56iKOAmIwaEN/VUudr5oqJ/Zm7poJBPVpfp21seVcWb9zDtHe0tc1ckeaDW4/jhAf/m/Q9cnO8lhVtckSfTpx8WBmzP98ed+IUGDNn377pWDoW5OLxeDj18F5WUFjv2M9TV47lqY834msO0LOkwHL9jBrQhU6FeVZwbmiZ83fgaw5YrsmLxvTn6MHhONPG350BJJ667+T2tMc8TjmsF68mma5eUd0YsWqR6d54/LLRLNm0h721Pl5ZGn6P9RU1seV+QyxYX8kxg7tbaaglBbkRGSWN/oD1+5nunV+dcajjPRIIwqmHl/HOF8Zx9pTaToV53PxyOJgplrgDuTle/vL9kcy8bgKvXDeB2T85JqKK3obdtby0JFx7/JZXVvDTlz5l5rItXPfcspS+VNOKrmn0c8sr4UyCHzy7jHe+2GlZB9FBn6U2H+ryzft4btEm62+7iC/ZGD4uWUW5ljLs7nf5ziPzY7Yf+/sPk55r+qXzMxCRPWt4Hz6782SG9TMsv9wcLxeO7s+phzuX2TEtukE9OnJ4n0785ITBgBEci168ORrzMwrzvJbvfETI4txX18Q5I/qy4s5TULb7JhFOFnuX4jzuONPZHdK7tEPc0cDU0yLLEtizJ564fDQnH1bGj44fBMCiBCmSHQtyKSnMs84/9fAya1/09/PCtePpErqGyaoHACeELPFDexsByA5RnWJdYzPXTx7EeSP7MnFQN8c2JOpIuxXnc9NJQ3jqSme3l5Mr5k/fGxGT9WS6JYb07Mg3IXdnaVEev/h/h3LfuUdGHLtuV01Cd+GZw3pz9OBuPHLJ
SFbefYq1PV765HdH9yfH6zzvYYzNDfW1zUUWnUGV6tT+1uBaSxygIDfH+hI9Hg/HDunuOCQ3mfXZNlY4rAyejG37GiKGfKYFXVXfRO/SDqyKyrKInsn2zZ5663h753FtKJACxk0a/TDs3N9A944FEcWMWkKi7yIV4sUdDhQny7K0Qx7nj+zHq8u3cPNJQ5n+fmTAqkN+DrN/cizVDU08Ovdrrjkm7CYYOaBzxIzI/BwvHo+RwTT9ouEc0qsTvwoFuC4e25/Pvtln+VVNX/vDlxxluB+edfbLglEPxnwfkxyvhzOO7M29b6ympDCXK48+iD/PWQtAXo7HCnLPv/0Eo6hXRS0DuhZxUPdivB5PTA0agBMPNYR4cM+OeDzOudMm0TVNjurfhRuOH8xRAzpHWM0AfToXMrBbMXvr9jGiv9GpXXX0QZw1vA9lnQpZfsfJ5Od6eX3FNqobmnhm4SaOP6QnFzqk9tqJFv5jh4Q/1+PxcNNJQ6lyWDruhWvGRcQ3TPp27sCogV3Iz/Vy7oi+vLw0bIyddkQvZnywDgjPFLZPNjukV0lCdwpAWWkhz18z3vr7pR+M54mPN/CTE4Y4plCaq1aVFOZS0+jnxhMGW22w+8rHHtTVCgpvjupEahr9VNY00q1j4u+yNbhaxKM56dAyHpu3IeExdn/a6zccw9MLNsYM5yYcHDlJYt2uyCCJ+WDO+XIXvToVxqTKRdfTmPPlTspvN3zrd8UJYtU0+CMKfM1bW8GlTyzm+smDOHt4H15e8g1XHX0QA7oV8dyiTTy3aBNv/fRYx/zXlq45GY9EOd2ZwMym6NQh/m1ZUpjHmt/8v4htz149ji+27ac5EGRMeRdyvB7rezn3KKOmyy/POJT/e+1zzhreh5+/+jl9SiMzUs4cZiw7N+fW45gxZy3/tvmdf3324Zx8WBm9Sws5tHcnzv/rAlRZCReP7c8Zw/rQo6SAR78/ijHlXehanE8wGGTGB+vI8Xo4Z0Qf/jb3a7oU5VGUn0u/LuFsneuOG5Tw+yjIzaGspJAdDjEdk2g3k9fr4WenRlr5f5kykr/NXU9hbg4/PWkIt7+6kpEDO1vHm5NQzIlyF48dAMAPJiVun4mZ5XNw92I++Nlkx2NKQlasXQB7dipgvy1gWd6tiI2VdXTIz8Hj8fDlPafh9cD2/Q1WJs74g7sx44N1eD04Tp4Z1q+Uf326NSJAPqa8S8Sotywqa2n8wd0Yf3C3iCy1V6+fyKbKWvp3Df9epiXe2xYIt79+ZMpI/M0BTp4+N8KoMNE7qpk4WEQ8IeMO7sb7t0yiusHP4/M2WMX243FYn048eOFwvthWFWG1ThiURMRDgctp72i27quP8TNHW8B2n2G8Smf2FKp1u2q49Alj5p5pFT23aDM9OxXwo8mDLWtww+5aVmzZxwmqjA75OZbomoEtk2AwyKbKuqTuh2gyZYnHw5wHUOdr5qkrxqRcVqG4IJexDimLdkYO6MLbNxnBww9uPc6xKh8YbpshZZHulcsmDLQ6BXNQVJDn5YqjwxNZTjsi7A669RTFraEqjj8/7RBuOGFwSjnbj182OqZdA7oWxYh4SUEut5wylF+/vjolN9xJh5VxUmji1fGqJ5/830lJz2kJ5si0JM4sTDA6C9OHbop4j46F5OWE86f7dTFE3AwOmiPQZ64aa82QPWpAZ3K9HnqUFETUTDc56dCyCP84GBltd511OGfOMFJ5y+JkLeXmeLnjzMM4sm8powZ2YVRU+qrp3+/WMTyS7Gl7L1PkOxbkWiO9PqWFbKtqwOuBj9ftZuLg9M9d+VaJOMDgnsYDeOKhtTEibn6hYET/zZvEXmjooO7F/GDSwXy+tYr3QoGN6CnI9l5+5ZZ9bN/XwIWj+5Hj9fLe6p0JxSdeDnFFTSP3v/klB/coJsdWvKihqdmabl7XGDlz7ZxH5lvZHpOG9uBPF42wUvJMvq6osYJrs358dNx27drfwOl/nsczV4V9kU4PSSYp
DlmVtY3+uNlH6eDgJIHM6MGNfbRj3jN9EqQl2vF6PZQUppb9YgqtncFlHa1Zm/mhujP3n3+kFbdwckdkm4HdDOPgh5NSK18xb+rxfKR3UVqUh71O1/SLRvDi4s0c1jt2gtBTV4yhMC+HovxcRg7oYq31Gs1k1ZMbTxjMkLIS7py1in11TfTvWsQRfUsZWtaRNTtr4mY6gfMMUxPzd8yx3Q+FeTn88aIREbEu0//dqTCXj247nuZAkBteWM7ctRVMPe2QuO/fWr51Im5y7lF9GVpWYvW+6+87nVqfn2cWbOSiMQMipvA/dOEIvty+nzEHdeXwPp3Iy/Fy8qFllohHF+Sxs2qrkS42tKyEa449mO+O7sd5oUV8f3XGoVTVN1mWRyKufGqJ4/Y6X7M1MaTW54+opGbPb5+7poKj7n2Pxf93YsT5N9um+0bn117x1GL21zfxg0kHU1HdyO4an7UOZVtwRMi/eCDZIukgJ0Fu5ZF9S7n3nMM5e3j8GiLpZNSALrzwyWYeOP9ILhozgNpGP8UFubwbWrKuPYh4944FlpWdCv27FnHphHIgMqOjR0kBPzlxiOM59k794UuOipnGcFD3YjbsriU/12uNgswMtCGhzJ+XfzCBb/bWWatTtRTTEo8uXPadoyLvBXPkdMm4gdbo+A/fHZ4wtnEgfGtF3OPxWKIAhgXVqTCPG06IvUmOGdKdY4ZEDnMSuR5+euIQ/hQKXpmYKWIjB3ShR0kBFdWNjD+4GwO7FUWI+PWTB/HXj9ZHnDu8Xykr4kz4qfM1W2tM1vuaqU5SxP8fCzdG/G1fE3J91IjCTMe87rnlMSlukHyFonQzWfXk3ZsnWQ9dW/G9sQP4dPM+xwVwPR6PJUDZ4LyRfenWMZ9JQ4xsEtPlNGFQN47sW8qtpwxNdHq7x+PxcNHo/jHPXyJ6OvjC3/rpsXHrFJnB4y7F+TEj1Zbws1MU26saOF71tIqyJeKmk8Jac6CfnYhvrYibLL/j5FadN6BrUdx9Z4/owzkjjIDWPa+vZuayLQy1+VHHHdSVN1Zup9HfTElhHq9eP4Hz/2rM+vv5aYfQr0sHfvmvVXQqzOXWUxSXTyyn3tfMoXe+nbBNdb7mpKsZPfJhZAdht9QSjQii/f5txdCy1NL9MklphzwevXQUyzbtzWhqWCp4PB4mq1jXUklhHq+HShK4nQcuGHbA7+GU5vjHi0Yw56tdCVMgW0J592Jevd6Y5PTpnfF15eKx/Vm4vjJtn5uMb72It6Q0rZ1epYX8+eKjyM/xct1zy8jP8TLtu8O4943VDOhaZPmL7zvvSC6dMJA+tqJBD5w/jDHlXRkZmmxiL1MKcHSoMNc/rhrLUaFjOqRQca/O54+pkrbxd2dw4oMfRUzjThduLKGbTqIDW4K7+M5RfWNcHekikUDff96Bd0ot4Vsv4gfC2cP7WLU+Du3TiXNG9OWcEZE3RV6O15pUYlJckMvlE8sjtj104XArn7W8e3GLfIgmdb7mmHrFEFmYx2TS0B5WWtaUcQN4/pPNsQclIdvuFEEQWo5rZ2xmi06FeTxx+eiIqfSt4byR/Th7eJ+Uj3/i8tERkybAEPHNDnUjzJTHUluK15UTy3nmqrEM6lEcd9FiQRDcj4h4Cpx4aJnjghSZYuZ1Ezjx0DJ+ffbhEds/+2Yfd/0nvHrKxWONEqC//c6RqLKSCJdOh/wcJg3twZxbJ0cEVK48ujziPR+yFfb50/dGoGw+abHDBaH9IyLeDjFLCZR3K45bz2Hpr07it98xakZMGtqDd26eZC25BfFXtbGvYP+n743gvJH9rL9Hl3flnZsn8dQV8cu7CoLQvhARb8d4vR4eunAEj1022lrhxaRrUX5MvmvA5sOOt6agufnGEwbH+PejC/yIS1wQ2j8S2GxHTBk3wHFh25MPK2PEgM4s3rCHT+84mV3VjY4TFuyLCJVGTYFe9IsTycvxWDnqnRxmETpVaRMEoX0jT2074rdRJTXtPPr9
UXyzpy7hpIFRA7uwevt+nr9mXEThHjBSJgGGhOpGRy+bBuEp5Wb6ZLLlvQRBaHtaJeJKqcnATMCMsn0O3As8A3QGtgBTtNapVTASktK1OD9pzvsdZx7GRWP6R8xUjebC0f0Z3LMkYQ70xEHduOmkIVyWxZmJgiC0jgPxif9Xaz059O9GYBrwlNZ6PLARmJKOBgqpk5/rTSjgYEzgSTaJxes1akC3dqKUIAjZI52BzcnAf0KvZwGnpvG9hQwSXVtbEAT3cCA+8cOUUm8BJcCvgRKttbmkzS7Aeb0tod3x4W2TJRNFEFxKa0V8LfAb4CVgIPARYE+X8CBzRVxDQQbW0hQEITu0SsS11luBF0J/blBK7QD6KKWKtNZ1GFZ4/CLcgiAIQlpolU9cKfU9pdTdodfdgTLgceCc0CHnAbPT0UBBEAQhPq11p7wBXKSUmo/REfwI+BR4USl1C6CBl9PTREEQBCEerXWn1ADnOuyafECtEQRBEFqE1E4RBEFwMSLigiAILibbtVNyAHbsiF2AVhAEQYjFppeOucDZFvHeAFOmyIx8QRCEFtIbWB+9MdsivgQ4FtgOtO0y4oIgCO4gB0PAlzjt9MhiuIIgCO5FApuCIAguxhWLQiil7gVOAAqBH2qtl7Zxk1JCKXUERkXH6Vrrh5VSPXGoua6UOheYinF9M7TWT7ZZox1QSt0HHA/kAQ8A/8Wd11EEPI0xw7gYuAdYhAuvxUQp1QFYhVHP/01cdi0tWZugPV+HiVLqEuBWjPpRd2C4QDJ6Le3eEldKHQ+M1lofDVwOPNTGTUoJpVQxMAOYY9scU3NdKVUS2n4acDQwVSnVMcvNjYtSahIwQms9ATgFmI4LryPE2cBSrfVxwPnAH3DvtZj8CtgTeu3Wa0m6NoEbriPUnlsx2ncm8B2ycC3tXsQxLMBZAFrrVYQKbbVtk1KiETidyEJgk4mtuT4GQ1iqQsXD5mMEf9sLC4ALQ6+rgHyMUZHbrgOt9Uta69+H/uyHYRlNxoXXAqCUOgQ4lHCdosm49FqimIw7r+NUYLbWukFrvU1rfS1ZuBY3uFN6Aytsf1dgDIc3tE1zUkNr7Qf8Sin7Zqea670xromo7e2C0HXUhP68GmPIfrbbrsOOUuoTjLadDsxz8bX8AbgBuCL0t+vurxCprE3ghuvoD/QIXUtH4C6ycC1usMR9UX+7uVa5/VrM63DF9SmlzgGuBW7CxdcBoLUeh1H75yXAb9vlmmtRSl0GzNVab7RtduPvYq5NcAZwKUY1VKe1Cdr7dQAUYBiYZwJXYcRfMn5/uUHEtwM9bX/3AHa2UVsOlGqbK8isuR59fe2uFrtS6lTgTuA0rfU+3Hsdo5VSAwC01ssx7v9aN14LhuhdoJRaBFyDEUSrd9u1aK23aq1f0FoHtNYbgB1AR7ddR4gdwEKtdbPWei2wnyzcX25wp7wF/Bb4q1JqJPC1bXjiNt7GqLn+IuGa64uBYUqpUowJUOOA69ushVGE2vUQcILWujK02XXXEWIiUA7copQqwxi+z8KF16K1vsh8HartvxEYhcuuRSn1PeAQrfXdDmsTuOY6QrwPPK6U+gOGsVmC4Q/P6LW4YrKPUuoB4GSMocnVWuvP27hJSVFKjQIexBCNJmArMAV4HiO9TQNXaK39SqnvYmQZBIDfa61fbJNGO6CU+gFwN7DGtvly4B+46DoAlFIFwFMYvssCjBTDZRgPmKuuxY5NxN/BZdcSysp4FsMy9WKkF36Ky67DJPS8XELYv7+EDF+LK0RcEARBcMYNPnFBEAQhDiLigiAILkZEXBAEwcWIiAuCILgYEXFBEAQXIyIuCILgYkTEBUEQXIyIuCAIgov5/6bydiasyF4JAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----\n",
      " ihouyrhhe ndsydnh   aosbdnn o h w\n",
      "  o  eM nd oanz f  soz rU lrh  oW\n",
      "mk  d ao inel hc: s oN rtuhiu ho t  ln\n",
      "lim ldokloi ose nNso 'dfhoml h,l n  nay sahtsi rh   nomo,reioMra\n",
      "asomnliu deam k Mm  fkaa nac \n",
      "----\n",
      "iter 601, loss 97.386248\n"
     ]
    }
   ],
   "source": [
    "# Training loop: runs until interrupted with Ctrl-C (KeyboardInterrupt).\n",
    "# Relies on globals defined in earlier cells: data, T_steps, H_size,\n",
    "# char_to_idx, pointer, iteration, smooth_loss, plot_iter, plot_loss,\n",
    "# forward_backward, update_status, update_paramters, DelayedKeyboardInterrupt.\n",
    "while True:\n",
    "    try:\n",
    "        # Delay Ctrl-C until the current step finishes so parameters are\n",
    "        # never left half-updated (presumably DelayedKeyboardInterrupt is a\n",
    "        # context manager defined in an earlier cell -- TODO confirm).\n",
    "        with DelayedKeyboardInterrupt():\n",
    "            # Reset hidden/cell state and rewind the data pointer when we\n",
    "            # would run past the end of the data, or on the first iteration.\n",
    "            if pointer + T_steps >= len(data) or iteration == 0:\n",
    "                g_h_prev = np.zeros((H_size, 1))\n",
    "                g_C_prev = np.zeros((H_size, 1))\n",
    "                pointer = 0\n",
    "\n",
    "\n",
    "            # Inputs are T_steps character indices; targets are the same\n",
    "            # window shifted one character ahead (next-char prediction).\n",
    "            inputs = ([char_to_idx[ch] \n",
    "                       for ch in data[pointer: pointer + T_steps]])\n",
    "            targets = ([char_to_idx[ch] \n",
    "                        for ch in data[pointer + 1: pointer + T_steps + 1]])\n",
    "\n",
    "            loss, g_h_prev, g_C_prev = \\\n",
    "                forward_backward(inputs, targets, g_h_prev, g_C_prev)\n",
    "            # Exponential moving average of the loss for smoother reporting.\n",
    "            smooth_loss = smooth_loss * 0.999 + loss * 0.001\n",
    "\n",
    "            # Print every hundred steps\n",
    "            if iteration % 100 == 0:\n",
    "                update_status(inputs, g_h_prev, g_C_prev)\n",
    "\n",
    "            # NOTE(review): 'paramters' is a typo, but the function is\n",
    "            # defined under this name in an earlier cell, so it is kept.\n",
    "            update_paramters()\n",
    "\n",
    "            # Record the (iteration, loss) pair for plotting.\n",
    "            plot_iter = np.append(plot_iter, [iteration])\n",
    "            plot_loss = np.append(plot_loss, [loss])\n",
    "\n",
    "            pointer += T_steps\n",
    "            iteration += 1\n",
    "    except KeyboardInterrupt:\n",
    "        # On Ctrl-C, show the final status once, then stop training.\n",
    "        update_status(inputs, g_h_prev, g_C_prev)\n",
    "        break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gradient Check\n",
    "\n",
    "Approximate the numerical gradients by changing parameters and running the model.\n",
    "\n",
    "Check if the approximated gradients are equal to the computed analytical gradients (by backpropagation).\n",
    "\n",
    "Try this on `num_checks` individual parameters picked randomly for each weight matrix and bias vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from random import uniform"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calculate numerical gradient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_numerical_gradient(param, idx, delta, inputs, target, h_prev, C_prev):\n",
    "    \"\"\"Approximate d(loss)/d(param.v.flat[idx]) by central difference.\n",
    "\n",
    "    Perturbs a single weight by +/- delta, re-runs the full\n",
    "    forward/backward pass each time, and returns the clipped numerical\n",
    "    gradient. The parameter value is restored before returning.\n",
    "    \"\"\"\n",
    "    old_val = param.v.flat[idx]\n",
    "\n",
    "    # Evaluate loss at [x + delta] and [x - delta].\n",
    "    # Bug fix: the original called forward_backward with the *global*\n",
    "    # 'targets' instead of the 'target' argument, silently ignoring what\n",
    "    # the caller passed in.\n",
    "    param.v.flat[idx] = old_val + delta\n",
    "    loss_plus_delta, _, _ = forward_backward(inputs, target,\n",
    "                                             h_prev, C_prev)\n",
    "    param.v.flat[idx] = old_val - delta\n",
    "    loss_minus_delta, _, _ = forward_backward(inputs, target,\n",
    "                                              h_prev, C_prev)\n",
    "\n",
    "    param.v.flat[idx] = old_val  # reset to the unperturbed value\n",
    "\n",
    "    grad_numerical = (loss_plus_delta - loss_minus_delta) / (2 * delta)\n",
    "    # Clip the numerical gradient because the analytical one is clipped too.\n",
    "    [grad_numerical] = np.clip([grad_numerical], -1, 1)\n",
    "\n",
    "    return grad_numerical"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check gradient of each parameter matrix/vector at `num_checks` individual values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gradient_check(num_checks, delta, inputs, target, h_prev, C_prev):\n",
    "    \"\"\"Compare analytical gradients (backprop) to numerical estimates.\n",
    "\n",
    "    For each parameter matrix/vector, pick num_checks random entries,\n",
    "    compute the central-difference gradient, and print any entry whose\n",
    "    relative error against the backprop gradient exceeds 1e-06.\n",
    "    \"\"\"\n",
    "    global parameters\n",
    "\n",
    "    # Run backprop once to populate the analytical gradients (param.d).\n",
    "    # Bug fix: the original called forward_backward with the *global*\n",
    "    # 'targets' instead of the 'target' argument, silently ignoring what\n",
    "    # the caller passed in.\n",
    "    _, _, _ = forward_backward(inputs, target, h_prev, C_prev)\n",
    "\n",
    "    for param in parameters.all():\n",
    "        # Copy the analytical gradients: forward_backward runs inside\n",
    "        # calc_numerical_gradient below and will overwrite param.d.\n",
    "        d_copy = np.copy(param.d)\n",
    "\n",
    "        # Test num_checks times\n",
    "        for i in range(num_checks):\n",
    "            # Pick a random index into the flattened parameter array.\n",
    "            rnd_idx = int(uniform(0, param.v.size))\n",
    "\n",
    "            grad_numerical = calc_numerical_gradient(param,\n",
    "                                                     rnd_idx,\n",
    "                                                     delta,\n",
    "                                                     inputs,\n",
    "                                                     target,\n",
    "                                                     h_prev, C_prev)\n",
    "            grad_analytical = d_copy.flat[rnd_idx]\n",
    "\n",
    "            # Relative-error denominator with epsilon to avoid div-by-zero.\n",
    "            # NOTE(review): this uses |num + ana|, not the more common\n",
    "            # |num| + |ana|; kept as-is to preserve the reported values.\n",
    "            err_sum = abs(grad_numerical + grad_analytical) + 1e-09\n",
    "            rel_error = abs(grad_analytical - grad_numerical) / err_sum\n",
    "\n",
    "            # If relative error is greater than 1e-06, report the entry\n",
    "            if rel_error > 1e-06:\n",
    "                print('%s (%e, %e) => %e'\n",
    "                      % (param.name, grad_numerical, grad_analytical, rel_error))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "W_f (0.000000e+00, 6.698919e-11) => 6.278338e-02\n",
      "W_C (0.000000e+00, -4.851823e-11) => 4.627314e-02\n",
      "W_C (4.973799e-09, 5.100140e-09) => 1.140885e-02\n",
      "W_o (-8.526513e-09, -9.187699e-09) => 3.533071e-02\n",
      "W_o (4.702443e-05, 4.702540e-05) => 1.036964e-05\n",
      "W_o (-2.624176e-05, -2.624390e-05) => 4.073094e-05\n",
      "W_o (9.947598e-09, 1.002907e-08) => 3.884056e-03\n",
      "W_o (7.176482e-08, 7.164410e-08) => 8.359478e-04\n",
      "W_o (-8.007817e-07, -8.009594e-07) => 1.109070e-04\n",
      "W_o (-1.236344e-07, -1.243928e-07) => 3.045352e-03\n",
      "W_o (1.218581e-06, 1.221003e-06) => 9.925255e-04\n",
      "W_v (0.000000e+00, 1.124962e-10) => 1.011206e-01\n",
      "b_f (0.000000e+00, 2.303990e-15) => 2.303985e-06\n",
      "b_i (0.000000e+00, 6.846087e-12) => 6.799537e-03\n",
      "b_C (0.000000e+00, 4.851823e-11) => 4.627314e-02\n",
      "b_o (-3.424816e-07, -3.423552e-07) => 1.842488e-04\n",
      "b_o (-6.268621e-05, -6.268663e-05) => 3.296235e-06\n",
      "b_o (6.002665e-06, 6.003044e-06) => 3.157823e-05\n",
      "b_o (-8.526513e-09, -9.187714e-09) => 3.533145e-02\n",
      "b_o (1.378240e-05, 1.378108e-05) => 4.778681e-05\n",
      "b_o (0.000000e+00, 6.845985e-12) => 6.799436e-03\n",
      "b_o (-5.556444e-07, -5.561016e-07) => 4.108290e-04\n",
      "b_o (-2.041389e-05, -2.041450e-05) => 1.480507e-05\n",
      "b_o (-2.016520e-06, -2.018793e-06) => 5.631819e-04\n",
      "b_o (-7.055689e-07, -7.070640e-07) => 1.057596e-03\n",
      "b_v (6.892265e-08, 6.929811e-08) => 2.696902e-03\n",
      "b_v (0.000000e+00, 2.730120e-10) => 2.144615e-01\n",
      "b_v (7.105427e-10, 4.006369e-10) => 1.467927e-01\n"
     ]
    }
   ],
   "source": [
    "# Check 10 random entries per parameter with perturbation delta = 1e-5,\n",
    "# using the last inputs/targets/state left over from the training loop.\n",
    "gradient_check(10, 1e-5, inputs, targets, g_h_prev, g_C_prev)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
